library(pacman)
p_load(here, fst, data.table, ggplot2)

# Read the original production prediction for one plant (plant 408).
# NOTE(review): plant_gen_fp is not created earlier in this file, so the next
# line errors in a fresh session. Its random pick is only printed; the file
# read below is hard-coded to plant 408.
sample(plant_gen_fp, 1)
orig_dt = read.fst(
  path = here('Data/electricity-generation/plant-model-fit/plant-gen-dt/plant-gen-dt-408.fst'),
  as.data.table = TRUE
)
# Run the new prediction function on the same plant.
# Presumably the sourced file defines prep_load_table(), prep_model_table(),
# prep_plant_table(), generate_epsilons() and predict_production_plant(),
# which are used below but not defined above in this file -- verify.
source('R/02-power-plants/predict-production.r')
# Cleaned OGE load data: drop AK, ASCC and HI, and lower-case region codes.
all_load_dt = 
    read.fst(
      path = here(
        "Data/electricity-generation/clean-load/nerc-load-dt-oge.fst"
      ),
      as.data.table = TRUE
    )[!(nerc_adj %in% c('AK','ASCC','HI')),
      .(nerc_adj = str_to_lower(nerc_adj), datetime_utc, excess_load)
    ]
# Tag each region: tre -> 'Texas', wecc/cal -> 'West', everything else ->
# 'East' (a factor with levels West, Texas, East).
all_load_dt[,    
  ic := fcase(
    nerc_adj == 'tre', 'Texas',
    nerc_adj %in% c('wecc','cal'), 'West',
    default = 'East'
  ) |> factor(levels = c('West','Texas','East'))]
load_dt = prep_load_table(all_load_dt, id_var = 'datetime_utc')
# Prepare the model and plant tables.
model_tables = prep_model_table()
plant_dt = prep_plant_table(model_tables$model_dt)
# NOTE(review): this result is overwritten immediately below, so the call only
# costs time (and may advance the RNG if it draws random epsilons).
epsilon_dt = generate_epsilons(plant_dt)  

# Instead, use the epsilon from the original fit, reshaped to
# (plant_id_eia as character, epsilon_id = 1, epsilon) -- presumably the
# layout predict_production_plant() expects.
# NOTE(review): [2,] keeps only the second row, i.e. one single epsilon value
# for this plant -- confirm that is intended.
epsilon_dt = orig_dt[,.(
  plant_id_eia = as.character(plant_id_eia), 
  epsilon_id = 1, 
  epsilon
)][2,]


# Predict plant 408 with the new code, using the epsilon chosen above.
new_dt = predict_production_plant(
    plant_id_eia = '408', 
    load_dt = load_dt, 
    plant_dt = plant_dt, 
    model_dt_long = model_tables$model_dt_long, 
    epsilon_dt = epsilon_dt
)
# Align the keys with orig_dt for the merge: integer plant IDs (orig_dt's
# plant_id_eia is evidently not character, since it is converted with
# as.character() above) and datetime_utc instead of id.
new_dt[,plant_id_eia := as.integer(plant_id_eia)]
setnames(new_dt, 'id','datetime_utc')
# Comparing! Columns present in both tables get merge suffixes:
# .x = original fit (orig_dt), .y = new prediction (new_dt).
comp_dt = 
  merge(
    orig_dt, 
    new_dt, 
    by = c('plant_id_eia','datetime_utc')
  )

# ok so the raw data is correct--must have to do with epsilon then 
# Raw fit (before epsilon is added): original vs new, with the 45-degree line.
ggplot(comp_dt, aes(x = fit_net_generation_mwh_raw,y = fit_net_gen_raw)) + 
geom_point(alpha = 0.5) + 
geom_abline(slope = 1, intercept = 0)

# Final fitted generation: original (.x) vs new (.y).
ggplot(comp_dt, aes(x = fit_net_generation_mwh.x,y = fit_net_generation_mwh.y)) + 
geom_smooth() + 
geom_abline(slope = 1, intercept = 0)


# Trying a new way: reuse the epsilons stored in each plant's original fit file
# Predict hourly production, emissions and damages for one plant. It reuses
# the epsilons stored in the plant's original fit file
# (plant-gen-dt-<id>.fst) rather than drawing new ones.
#
# Arguments:
#   plant_id_eia_in: plant ID. It is matched against model_dt_long$plant_id_eia
#     and used to build the original-fit file name.
#   load_dt: load table, merged with the plant's coefficient rows on nerc_adj
#     and square. Together with model_dt_long it must supply id, excess_load,
#     coef and est_extremum (presumably id/excess_load from load_dt and
#     coef/est_extremum from model_dt_long -- verify).
#   plant_dt: plant table merged on plant_id_eia. It provides namepcap,
#     net_generation_mwh_split, emissions_rate_{low,high}_{co2e,so2,nox,pm25}
#     and md_per_mwh_{low,high}. Its namepcap, log_scale and intercept
#     columns are dropped from the result.
#   model_dt_long: long coefficient table containing plant_id_eia, nerc_adj
#     and square.
#
# Returns a data.table keyed on (plant_id_eia, id), one row per plant-hour
# assuming plant_dt has one row per plant. It contains the remaining plant_dt
# columns, fit_net_gen_raw, fit_net_generation_mwh (censored to
# [0, namepcap]), fit_emissions_tons_{co2e,so2,nox,pm25} and damages.
# Intercept-only plants (no coefficient rows) get a zero-row placeholder whose
# columns do NOT match that layout.
pred_prod_plant = function(
    plant_id_eia_in, 
    load_dt, 
    plant_dt, 
    model_dt_long
){
  # Progress indicator when mapped over many plants
  message(plant_id_eia_in)
  # Attach this plant's coefficients to every matching load row
  prod_dt = 
    merge(
      load_dt,
      model_dt_long[plant_id_eia == plant_id_eia_in],
      by = c('nerc_adj','square'),
      allow.cartesian = TRUE
    )
  # Intercept-only models have no coefficient rows, so return an empty table.
  # NOTE(review): these placeholder columns (quantile, tot_eload, md_*) do not
  # match the table returned at the end of this function. Callers that bind
  # results together should drop zero-row results first.
  if(nrow(prod_dt) == 0){
    return(
      data.table(
        plant_id_eia = character(),
        quantile = numeric(),
        tot_eload = numeric(),
        fit_net_generation_mwh = numeric(),
        fit_emissions_tons_co2e = numeric(),
        fit_emissions_tons_so2 = numeric(),
        fit_emissions_tons_nox = numeric(),
        fit_emissions_tons_pm25 = numeric(),
        damages = numeric(),
        md_cal = numeric(),
        md_mro = numeric(),
        md_npcc = numeric(),
        md_rfc = numeric(),
        md_serc = numeric(),
        md_tre = numeric(),
        md_wecc = numeric()
      )
    )
  }
  # Each term contributes coef * excess_load. When the term has a positive
  # estimated extremum and the load exceeds it, the contribution is capped at
  # coef * est_extremum instead. Terms are then summed per plant-hour.
  coef_prod_dt = 
    prod_dt[,.(
      plant_id_eia, 
      id,
      fit_net_gen_raw = fifelse(
        !is.na(est_extremum) & est_extremum > 0 & excess_load > est_extremum,
        coef*est_extremum,
        coef*excess_load
      )
    )] |>
    gby(plant_id_eia, id) |>
    fsum() |>
    setkey(plant_id_eia, id)
  # Epsilons from the plant's original fit
  orig_eps_dt = read.fst(
    path = here(
      'Data/electricity-generation/plant-model-fit/plant-gen-dt',
      paste0('plant-gen-dt-',plant_id_eia_in,'.fst')
    ),
    as.data.table = TRUE
  )[,.(
    plant_id_eia = as.character(plant_id_eia), 
    id = datetime_utc,
    epsilon
  )]
  # Reuse every epsilon one year later so that hours after the original fit
  # window also get one. Originals come first, so where the shifted copy lands
  # on an hour the original already covers (multi-year fits), unique() keeps
  # the original epsilon. Without this, the merge below would emit two rows
  # for that hour. Shifts that land on a non-existent date (Feb 29 + 1 year)
  # come back NA from lubridate and are dropped.
  orig_eps_dt = 
    rbind(
      orig_eps_dt, 
      orig_eps_dt[,.(
        plant_id_eia, 
        id = id + years(1),
        epsilon
      )]
    )[!is.na(id)] |>
    unique(by = c('plant_id_eia','id'))
  # Join plant-level parameters and the epsilons
  gen_dt = 
    merge(
      plant_dt,
      coef_prod_dt,
      by = 'plant_id_eia'
    ) |>
    merge(
      orig_eps_dt,
      by = c('plant_id_eia','id')
    )
  # Add the epsilon and censor generation to [0, namepcap]
  gen_dt[,
    fit_net_generation_mwh := fcase(
      fit_net_gen_raw + epsilon >= 0 
      & fit_net_gen_raw + epsilon <= namepcap, 
      fit_net_gen_raw + epsilon,
      fit_net_gen_raw + epsilon < 0, 0,
      fit_net_gen_raw + epsilon > namepcap, namepcap
    )
  ]
  # Removing columns we don't need anymore 
  gen_dt[,c('namepcap','log_scale','intercept','epsilon') := NULL]
  # Two-segment linear total: rate_low applies to generation up to the split
  # point and rate_high to generation above it.
  split_linear = function(gen, split, rate_low, rate_high){
    fifelse(
      gen >= split,
      (gen-split)*rate_high + split*rate_low,
      gen*rate_low
    )
  }
  # Emissions and damages, each split at net_generation_mwh_split
  gen_dt[, `:=`(
    fit_emissions_tons_co2e = split_linear(
      fit_net_generation_mwh, net_generation_mwh_split,
      emissions_rate_low_co2e, emissions_rate_high_co2e
    ),
    fit_emissions_tons_so2 = split_linear(
      fit_net_generation_mwh, net_generation_mwh_split,
      emissions_rate_low_so2, emissions_rate_high_so2
    ),
    fit_emissions_tons_nox = split_linear(
      fit_net_generation_mwh, net_generation_mwh_split,
      emissions_rate_low_nox, emissions_rate_high_nox
    ),
    fit_emissions_tons_pm25 = split_linear(
      fit_net_generation_mwh, net_generation_mwh_split,
      emissions_rate_low_pm25, emissions_rate_high_pm25
    ),
    damages = split_linear(
      fit_net_generation_mwh, net_generation_mwh_split,
      md_per_mwh_low, md_per_mwh_high
    )
  )]
  gen_dt[]
}


# Predict production for every plant in one NERC region, join the predictions
# to the region's observed OGE generation and emissions, and write the result
# to Data/electricity-generation/plant-model-fit/predict-production/
# one-epsilon/<region>.fst.
#
# Arguments:
#   nerc_adj_in: region code as stored in plant_dt$nerc_adj; lower-cased when
#     building file names.
#   load_dt, plant_dt, model_dt_long: passed through to pred_prod_plant().
#
# Called for its side effect of writing the region file; returns the result
# of gc(). Errors if no plant in the region produced any predictions.
predict_production_region = function(
  nerc_adj_in, load_dt, plant_dt, model_dt_long
){
  region_file = tolower(nerc_adj_in)
  # Predict each plant in the region. Use the model_dt_long argument, not a
  # global, so the caller controls which coefficients are applied.
  plant_fits = map(
    plant_dt[nerc_adj == nerc_adj_in]$plant_id_eia, 
    pred_prod_plant, 
    load_dt = load_dt, 
    plant_dt = plant_dt, 
    model_dt_long = model_dt_long
  )
  # For intercept-only plants, pred_prod_plant() returns a zero-row
  # placeholder with a different column set. Binding it with real results
  # makes rbindlist() fail on the column count (or bind columns by position),
  # so drop those placeholders first.
  plant_fits = Filter(function(fit) nrow(fit) > 0, plant_fits)
  if (length(plant_fits) == 0) {
    stop('No predictions produced for region ', nerc_adj_in, call. = FALSE)
  }
  fit_prod_dt = rbindlist(plant_fits) |> setnames('id','datetime_utc')
  # Load actual production data 
  oge_elec_gen_dt =
    read.fst(
      path = here(paste0(
        "Data/electricity-generation/elec-gen-dt/elec-gen-dt-",
        region_file,
        ".fst"
      )),
      as.data.table = TRUE
    )[,.(
      plant_id_eia = as.character(plant_id_eia), 
      year, 
      datetime_utc, 
      net_generation_mwh,
      co2e_mass_lb_for_electricity, 
      nox_mass_lb_for_electricity, 
      so2_mass_lb_for_electricity
    )] |> 
    setkey(plant_id_eia, datetime_utc) 
  # Merging together and saving the results 
  out_dt = 
    merge(
      fit_prod_dt, 
      oge_elec_gen_dt, 
      by = c('plant_id_eia','datetime_utc')
    )
  write.fst(
    x = out_dt, 
    path = here(paste0(
      'Data/electricity-generation/plant-model-fit/predict-production/one-epsilon/', 
      region_file,
      '.fst'
    ))
  )
  # Free the large intermediates before the next region runs
  rm(out_dt, fit_prod_dt, oge_elec_gen_dt, plant_fits)
  gc()
}

  # Prepping data (the same load/model/plant preparation as at the top of the
  # script)
  all_load_dt = 
    read.fst(
      path = here(
        "Data/electricity-generation/clean-load/nerc-load-dt-oge.fst"
      ),
      as.data.table = TRUE
    )[!(nerc_adj %in% c('AK','ASCC','HI')),
      .(nerc_adj = str_to_lower(nerc_adj), datetime_utc, excess_load)
    ]
  all_load_dt[,    
    ic := fcase(
      nerc_adj == 'tre', 'Texas',
      nerc_adj %in% c('wecc','cal'), 'West',
      default = 'East'
    ) |> factor(levels = c('West','Texas','East'))]
  load_dt = prep_load_table(all_load_dt, id_var = 'datetime_utc')
  model_tables = prep_model_table()
  plant_dt = prep_plant_table(model_tables$model_dt)
  # Directory that predict_production_region() writes its <region>.fst files to
  nerc_fit_path='Data/electricity-generation/plant-model-fit/predict-production/one-epsilon'
  # Region codes (upper-cased) that already have an output file. This assumes
  # plant_dt$nerc_adj uses upper-case codes -- verify.
  nerc_already_run = list.files(here(nerc_fit_path)) |> 
    str_extract('cal|wecc|tre|serc|mro|npcc|rfc')|>
    str_to_upper() |>
    na.omit()
  # Running for each region that does not have an output file yet
  nerc_to_run = plant_dt[!(nerc_adj %in% nerc_already_run)]$nerc_adj |> unique()
  # NOTE(review): at top level this prints the list map() returns (one gc()
  # result per region); wrap it in invisible() or use walk() to silence that.
  map(
    nerc_to_run,
    predict_production_region,
    load_dt = load_dt, 
    plant_dt = plant_dt, 
    model_dt_long = model_tables$model_dt_long
  )


    # Actual vs predicted production in GWh, faceted by ic with free scales.
    # Points are drawn gray regardless of year; geom_smooth() inherits the
    # year colour, so there is one smoother per year.
    # NOTE(review): ic_gen_dt is not created anywhere above in this file.
    ggplot(
      ic_gen_dt,
      aes(
        x = net_generation_mwh/1e3, 
        y = fit_net_generation_mwh/1e3,
        color = as.character(year)
      )
    ) + 
    geom_point(alpha = 0.01, color = 'gray30', shape = 19) +
    geom_smooth() +
    geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
    facet_wrap(
      ~ic,
      scales = "free"
    )+ 
    labs(
      x = "Actual Production (GWh)",
      y = "Predicted Production (GWh)"
    ) + 
    theme_minimal() +
    theme(legend.position = 'bottom')


# Checking one more plant (994), this time with pred_prod_plant(), which reuses
# the epsilons stored in the plant's original fit file.
new_dt = 
  pred_prod_plant(
    '994',
    load_dt, 
    plant_dt, 
    model_tables$model_dt_long
  ) |> setnames(
    'id','datetime_utc'
  )
# Original fit for the same plant, with character IDs to match new_dt
orig_dt = read.fst(
  path = here(
    'Data/electricity-generation/plant-model-fit/plant-gen-dt/plant-gen-dt-994.fst'
    ),
  as.data.table = TRUE
)[,plant_id_eia := as.character(plant_id_eia)]

# The merge order is reversed relative to the plant-408 check:
# here .x = new prediction (new_dt) and .y = original fit (orig_dt).
comp_dt = 
  merge(
    new_dt,
    orig_dt,
    by = c('plant_id_eia','datetime_utc')
  )


ggplot(comp_dt, aes(x = fit_net_generation_mwh.x,y = fit_net_generation_mwh.y)) + 
geom_point() + 
geom_abline(slope = 1, intercept = 0)


# Now we can compare the region-fuel level generation.
# NOTE(review): elec_gen_dt and nerc_fuel_gen_dt are not created anywhere
# above in this file.
comp_dt = 
  merge(
    elec_gen_dt, 
    nerc_fuel_gen_dt, 
    by = c('nerc_adj','fuel_category','datetime_utc')
  )

# CAL only: total fitted generation vs fitted generation, with the 45-degree line
ggplot(comp_dt[nerc_adj == 'CAL'], aes(x = tot_fit_net_generation_mwh, y = fit_net_generation_mwh)) + 
geom_point() + 
geom_abline(slope = 1, intercept = 0)


# Let's get all of the plants in CAL and compare them.
# Keep the original per-plant fit files whose plant ID (the digits right
# between '-' and '.fst' in the path) belongs to a CAL plant.
# NOTE(review): plant_gen_fp and plant_info_dt are not created anywhere above
# in this file.
nerc_plant_gen_fp = 
  plant_gen_fp[
    str_extract(plant_gen_fp, '(?<=-)\\d*(?=\\.fst)') |> as.integer()
    %in% plant_info_dt[nerc_adj == 'CAL']$plant_id_eia
  ]

# New CAL predictions, as written by predict_production_region()
# (nerc_fit_path is set in the region-driver section above)
new_cal_dt = 
  read.fst(
    path = here(
      nerc_fit_path, 
      paste0(str_to_lower('cal'),'.fst')
    ),
    as.data.table = TRUE
  )
# Original fits for the CAL plants, stacked into one table with character IDs
orig_cal_dt =  
  map_dfr(
    nerc_plant_gen_fp, 
    read.fst,
    as.data.table = TRUE
  )[,plant_id_eia := as.character(plant_id_eia)]

# .x = new prediction (new_cal_dt), .y = original fit (orig_cal_dt)
comp_dt = 
  merge(
    new_cal_dt, 
    orig_cal_dt, 
    by = c('plant_id_eia','datetime_utc')
  )

ggplot(comp_dt[plant_id_eia == '63699'], aes(x = fit_net_generation_mwh.x,y = fit_net_generation_mwh.y)) + 
geom_point() + 
geom_abline(slope = 1, intercept = 0)

# Distribution of (new - original) fitted generation across all CAL rows
ggplot(comp_dt, aes(x = fit_net_generation_mwh.x - fit_net_generation_mwh.y)) + 
geom_density()

# ok so ~9% are different 

# Per plant, the share of rows where the new and original fitted generation
# differ by more than 0.00001; show only plants with any difference.
comp_dt[
  ,
  fmean(abs(fit_net_generation_mwh.x - fit_net_generation_mwh.y)> 0.00001), 
  keyby = plant_id_eia
][ V1 > 0]


ggplot(comp_dt[plant_id_eia == '63699'], aes(x = fit_net_generation_mwh.x,y = fit_net_generation_mwh.y)) + 
geom_point() + 
geom_abline(slope = 1, intercept = 0)

# Looks like a square term is missing in one of these.
# Raw fits: new (fit_net_gen_raw) vs original (fit_net_generation_mwh_raw).
ggplot(comp_dt[plant_id_eia == '63699'], aes(x = fit_net_gen_raw,y = fit_net_generation_mwh_raw)) + 
geom_point() + 
geom_abline(slope = 1, intercept = 0)


# Inspect plant 63699's rows in both model tables
model_tables$model_dt_long[plant_id_eia == '63699']
model_tables$model_dt[plant_id_eia == '63699']

# Worked example with the first CAL load value, 17161.30:
# Fit Net Gen Raw =   0.3134875 + 0.0001151731 * 17161.30 + -8.973444e-10 * 17161.30^2